
#include "commonCuda/rtweekend.cuh"
#include "commonCuda/camera.cuh"
#include "commonCuda/texture.cuh"
#include "sphere.cuh"
#include "moving_sphere.cuh"
#include "hittable_list.cuh"
#include "material.cuh"

#include <iostream>
#include <time.h>
#include <float.h>
#include <curand_kernel.h>

/* Note, doing a straight translation from the original C++ will mean that any
   floating-point constants will be doubles and math on the GPU will be forced
   to be double-precision.  This will hurt our performance unnecessarily.
   Special attention to floating point constants must be taken (e.g. 0.5 ->
   0.5f). */

// limited version of checkCudaErrors from helper_cuda.h in CUDA examples
#define checkCudaErrors(val) check_cuda( (val), #val, __FILE__, __LINE__ )

void check_cuda(cudaError_t result, char const *const func, const char *const file, int const line) {
  if (result) {
    std::cerr << "CUDA error = " << static_cast<unsigned int>(result) << " at " <<
      file << ":" << line << " '" << func << "' \n";
    // Make sure we call CUDA Device Reset before exiting
    cudaDeviceReset();
    exit(99);
  }
}

// Matching the C++ code would recurse enough into color() calls that
// it was blowing up the stack, so we have to turn this into a
// limited-depth loop instead.  Later code in the book limits to a max
// depth of 50, so we adapt this a few chapters early on the GPU.
__device__ vec3 color(const ray& r, hittable **world, curandState *local_rand_state) {
  ray cur_ray = r;
  vec3 cur_attenuation = vec3(1.0,1.0,1.0);
  for(int i = 0; i < 50; i++) {
    hit_record rec;
    if ((*world)->hit(cur_ray, 0.001f, FLT_MAX, rec)) {
      ray scattered;
      vec3 attenuation;
      if(rec.mat_ptr->scatter(cur_ray, rec, attenuation, scattered, local_rand_state)) {
        cur_attenuation *= attenuation;
        cur_ray = scattered;
      }
      else {
        return vec3(0.0,0.0,0.0);
      }
    }
    else {
      vec3 unit_direction = unit_vector(cur_ray.direction());
      float t = 0.5f*(unit_direction.y() + 1.0f);
      vec3 c = (1.0f-t)*vec3(1.0, 1.0, 1.0) + t*vec3(0.5, 0.7, 1.0);
      return cur_attenuation * c;
    }
  }
  return vec3(0.0,0.0,0.0); // exceeded recursion
}

__global__ void rand_init(curandState *rand_state) {
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    curand_init(1984, 0, 0, rand_state);
  }
}

__global__ void render_init(int max_x, int max_y, curandState *rand_state) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  int j = threadIdx.y + blockIdx.y * blockDim.y;
  if((i >= max_x) || (j >= max_y)) return;
  int pixel_index = j*max_x + i;
  //Each thread gets same seed, a different sequence number, no offset
  curand_init(1984, pixel_index, 0, &rand_state[pixel_index]);
}

__global__ void render(vec3 *fb, int max_x, int max_y, int ns, camera **cam, hittable **world, curandState *rand_state) {
  int i = threadIdx.x + blockIdx.x * blockDim.x;
  int j = threadIdx.y + blockIdx.y * blockDim.y;
  if((i >= max_x) || (j >= max_y)) return;
  int pixel_index = j*max_x + i;
  curandState local_rand_state = rand_state[pixel_index];
  vec3 col(0,0,0);
  for(int s=0; s < ns; s++) {
    float u = float(i + curand_uniform(&local_rand_state)) / float(max_x);
    float v = float(j + curand_uniform(&local_rand_state)) / float(max_y);
    ray r = (*cam)->get_ray(u, v, &local_rand_state);
    col += color(r, world, &local_rand_state);
  }
  rand_state[pixel_index] = local_rand_state;
  col /= float(ns);
  col[0] = sqrt(col[0]);
  col[1] = sqrt(col[1]);
  col[2] = sqrt(col[2]);
  fb[pixel_index] = col;
}

#define RND (curand_uniform(&local_rand_state))

__global__ void create_world(hittable **d_list, hittable **d_world, camera **d_camera, int nx, int ny, curandState *rand_state) {
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    curandState local_rand_state = *rand_state;
    // d_list[0] = new sphere(vec3(0,-1000.0,-1), 1000,
    //                        new lambertian(vec3(0.5, 0.5, 0.5)));
    Texture *checker = new checker_texture(
                                           new constant_texture(vec3(0.2, 0.3, 0.1)),
                                           new constant_texture(vec3(0.9, 0.9, 0.9))
                                           );
    d_list[0] = new moving_sphere(vec3(0,-1000.0,-1), vec3(0,-1000.0,-1),
                                  0.f, 1.f,
                                  1000,
                                  new lambertian(checker));
    // d_list[0] = new moving_sphere(vec3(0,-1000.0,-1), vec3(0,-1000.0,-1),
    //                               0.f, 1.f,
    //                               1000,
    //                               new lambertian(new constant_texture(vec3(0.5, 0.5, 0.5))));
    int i = 1;
    for(int a = -11; a < 11; a++) {
      for(int b = -11; b < 11; b++) {
        float choose_mat = RND;
        vec3 center(a+RND,0.2,b+RND);
        if(choose_mat < 0.8f) {
          // d_list[i++] = new sphere(center, 0.2,
          //                          new lambertian(vec3(RND*RND, RND*RND, RND*RND)));
          d_list[i++] = new moving_sphere(center, center+vec3(0, 0.5*RND, 0),
                                          0.f, 1.f,
                                          0.2,
                                          new lambertian(new constant_texture(vec3(RND*RND, RND*RND, RND*RND))));
        }
        else if(choose_mat < 0.95f) {
          // d_list[i++] = new sphere(center, 0.2,
          //                          new metal(vec3(0.5f*(1.0f+RND), 0.5f*(1.0f+RND), 0.5f*(1.0f+RND)), 0.5f*RND));
          d_list[i++] = new moving_sphere(center, center,
                                   0.f, 1.f,
                                   0.2,
                                   new metal(vec3(0.5f*(1.0f+RND), 0.5f*(1.0f+RND), 0.5f*(1.0f+RND)), 0.5f*RND));
        }
        else {
          //d_list[i++] = new sphere(center, 0.2, new dielectric(1.5));
          d_list[i++] = new moving_sphere(center, center, 0.f, 1.f, 0.2, new dielectric(1.5));
        }
      }
    }
    // d_list[i++] = new sphere(vec3(0, 1,0),  1.0, new dielectric(1.5));
    // d_list[i++] = new sphere(vec3(-4, 1, 0), 1.0, new lambertian(vec3(0.4, 0.2, 0.1)));
    // d_list[i++] = new sphere(vec3(4, 1, 0),  1.0, new metal(vec3(0.7, 0.6, 0.5), 0.0));
    d_list[i++] = new moving_sphere(vec3(0, 1,0),  vec3(0, 1,0),   0.f, 1.f, 1.0, new dielectric(1.5));
    d_list[i++] = new moving_sphere(vec3(-4, 1, 0),vec3(-4, 1, 0), 0.f, 1.f, 1.0,
                                    new lambertian(new constant_texture(vec3(0.4, 0.2, 0.1))));
    d_list[i++] = new moving_sphere(vec3(4, 1, 0), vec3(4, 1, 0),  0.f, 1.f, 1.0, new metal(vec3(0.7, 0.6, 0.5), 0.0));
    *rand_state = local_rand_state;
    *d_world  = new hittable_list(d_list, 22*22+1+3);

    vec3 lookfrom(13,2,3);
    vec3 lookat(0,0,0);
    float dist_to_focus = 10.0; (lookfrom-lookat).length();
    //float aperture = 0.1f;
    float aperture = 0.0f;
    *d_camera   = new camera(lookfrom,
                             lookat,
                             vec3(0,1,0),
                             30.0,
                             float(nx)/float(ny),
                             aperture,
                             dist_to_focus,
                             0.f, 1.f);
  }
}

__global__ void free_world(hittable **d_list, hittable **d_world, camera **d_camera) {
  for(int i=0; i < 22*22+1+3; i++) {
    //delete ((sphere *)d_list[i])->mat_ptr;
    delete ((moving_sphere *)d_list[i])->mat_ptr;
    delete d_list[i];
  }
  delete *d_world;
  delete *d_camera;
}

int main (int argc, char** argv) {

  // default values
  bool SUPER_QUALITY_RENDER = !true;
  bool HIGH_QUALITY_RENDER = !true;
  bool MEDIUM_QUALITY_RENDER = !true;
  bool PROFILE_RENDER = !true;

  // handle command line arguments
  if (argc >= 2) {
    // first command line argument is "SH"?
    if (std::string(argv[1]) == "PR") {
      PROFILE_RENDER = true;
    }
    // first command line argument is "SH"?
    if (std::string(argv[1]) == "SH") {
      SUPER_QUALITY_RENDER = true;
    }
    // first command line argument is "HQ"?
    if (std::string(argv[1]) == "HQ") {
      HIGH_QUALITY_RENDER = true;
    }
    // first command line argument is "MQ"?
    if (std::string(argv[1]) == "MQ") {
      MEDIUM_QUALITY_RENDER = true;
    }
  }

  int nx, ny, ns;
  int tx = 8;
  int ty = 8;

  if (PROFILE_RENDER) {
    nx = tx*8;
    ny = ty*4;
    ns = 10;
  } else if (SUPER_QUALITY_RENDER) {
    nx = 600;
    ny = 400;
    ns = 100;
    nx *= 2; ny *= 2;
    ns /= 2;
  } else if (HIGH_QUALITY_RENDER) {
    nx = 600;
    ny = 400;
    ns = 100;
  } else if (MEDIUM_QUALITY_RENDER) {
    nx = 1200;
    ny = 800;
    ns = 20;
  } else {
    nx = 1200;
    ny = 800;
    ns = 10;
  }


  std::cerr << "Rendering a " << nx << "x" << ny << " image with " << ns << " samples per pixel ";
  std::cerr << "in " << tx << "x" << ty << " blocks.\n";

  int num_pixels = nx*ny;
  size_t fb_size = num_pixels*sizeof(vec3);

  // allocate FB
  vec3 *fb;
  checkCudaErrors(cudaMallocManaged((void **)&fb, fb_size));

  // allocate random state
  curandState *d_rand_state;
  checkCudaErrors(cudaMalloc((void **)&d_rand_state, num_pixels*sizeof(curandState)));
  curandState *d_rand_state2;
  checkCudaErrors(cudaMalloc((void **)&d_rand_state2, 1*sizeof(curandState)));

  // we need that 2nd random state to be initialized for the world creation
  rand_init<<<1,1>>>(d_rand_state2);
  checkCudaErrors(cudaGetLastError());
  checkCudaErrors(cudaDeviceSynchronize());

  // make our world of hittables & the camera
  hittable **d_list;
  int num_hittables = 22*22+1+3;
  checkCudaErrors(cudaMalloc((void **)&d_list, num_hittables*sizeof(hittable *)));
  hittable **d_world;
  checkCudaErrors(cudaMalloc((void **)&d_world, sizeof(hittable *)));
  camera **d_camera;
  checkCudaErrors(cudaMalloc((void **)&d_camera, sizeof(camera *)));
  create_world<<<1,1>>>(d_list, d_world, d_camera, nx, ny, d_rand_state2);
  checkCudaErrors(cudaGetLastError());
  checkCudaErrors(cudaDeviceSynchronize());

  clock_t start, stop;
  start = clock();
  // Render our buffer
  dim3 blocks(nx/tx+1,ny/ty+1);
  dim3 threads(tx,ty);
  render_init<<<blocks, threads>>>(nx, ny, d_rand_state);
  checkCudaErrors(cudaGetLastError());
  checkCudaErrors(cudaDeviceSynchronize());
  render<<<blocks, threads>>>(fb, nx, ny,  ns, d_camera, d_world, d_rand_state);
  checkCudaErrors(cudaGetLastError());
  checkCudaErrors(cudaDeviceSynchronize()); // errors when profiling
  //cudaDeviceSynchronize(); // errors when profiling
  stop = clock();
  double timer_seconds = ((double)(stop - start)) / CLOCKS_PER_SEC;
  std::cerr << "took " << timer_seconds << " seconds.\n";

  // Output FB as Image
  std::cout << "P3\n" << nx << " " << ny << "\n255\n";
  for (int j = ny-1; j >= 0; j--) {
    for (int i = 0; i < nx; i++) {
      size_t pixel_index = j*nx + i;
      int ir = int(255.99*fb[pixel_index].r());
      int ig = int(255.99*fb[pixel_index].g());
      int ib = int(255.99*fb[pixel_index].b());
      std::cout << ir << " " << ig << " " << ib << "\n";
    }
  }

  // clean up
  checkCudaErrors(cudaDeviceSynchronize());  // errors in profiler
  //cudaDeviceSynchronize();
  free_world<<<1,1>>>(d_list,d_world,d_camera);
  checkCudaErrors(cudaGetLastError());
  //cudaGetLastError();
  checkCudaErrors(cudaFree(d_camera));
  checkCudaErrors(cudaFree(d_world));
  checkCudaErrors(cudaFree(d_list));
  checkCudaErrors(cudaFree(d_rand_state));
  checkCudaErrors(cudaFree(fb));

  cudaDeviceReset();
}
